import hashlib
import pickle
from functools import partial
from os.path import expanduser

import pandas as pd
import hs_udata.apis.stock.api as api
from diskcache import Cache
from hs_udata.utils.datetime_func import convert_date
from ratelimiter import RateLimiter
from typing import Callable
import datetime as dt

from tqdm import tqdm

# User home directory with backslashes turned into forward slashes, so the cache path
# below has a uniform separator on every platform.
home = "/".join(expanduser("~").split("\\"))
# Module-wide disk cache rooted at ~/.udata_cache (used for the trading calendar).
# NOTE(review): timeout / sqlite_synchronous are passed straight to diskcache; see its
# documentation for their exact meaning.
cache = Cache(
    directory=f"{home}/.udata_cache",
    timeout=60,
    sqlite_synchronous=0,
)


def get_column_names(func: Callable) -> dict:
    """
    Extract the English-code -> Chinese-label mapping from a function's docstring.

    Lines after the first line containing "输出参数" (output parameters) are parsed until
    the first blank (empty or whitespace-only) line. For a field line such as
    ``:param prod_code : 证券代码,`` the key is the second-to-last space-separated token
    between the first and second ``:`` and the value is the text after the second ``:``,
    stripped and with its last character (the trailing punctuation) removed.

    Lines that do not have that shape are skipped instead of raising IndexError, which
    would otherwise abort the module import (this runs for every API name at import time).

    :param func: any object with a ``__doc__`` attribute
    :return: dict mapping field codes to labels; empty when there is no docstring
    """
    doc = func.__doc__
    if doc is None:
        return {}
    mapping = {}
    in_output = False
    for line in doc.split("\n"):
        if "输出参数" in line:
            in_output = True
            continue
        if not in_output:
            continue
        # A blank line ends the section; whitespace-only lines (e.g. the indentation before
        # a closing triple quote) count as blank too.
        if not line.strip():
            break
        parts = line.split(":")
        if len(parts) < 3:
            continue
        tokens = parts[1].split(" ")
        if len(tokens) < 2:
            continue
        mapping[tokens[-2].strip()] = parts[2].strip()[:-1]
    return mapping


def get_func_info(func: Callable) -> dict:
    """
    Split a function's docstring into description, input and output sections.

    Lines before any marker belong to "function". A line containing "输入参数" (input
    parameters) switches to "input", and a line containing "输出参数" (output parameters)
    switches to "output". The "输出参数" test wins when a line contains both, and marker
    lines are dropped. Each section is built by prefixing every one of its lines with a
    newline character and concatenating them, so a non-empty section starts with a newline.

    :param func: any object with a ``__doc__`` attribute
    :return: dict with keys "function", "input", "output" (all "" when there is no docstring)
    """
    sections = {"function": [], "input": [], "output": []}
    doc = func.__doc__
    if doc is not None:
        current = sections["function"]
        for text in doc.split("\n"):
            if "输出参数" in text:
                current = sections["output"]
            elif "输入参数" in text:
                current = sections["input"]
            else:
                current.append(text)
    return {key: "".join("\n" + text for text in lines) for key, lines in sections.items()}


# English field code -> Chinese label, for every public (non-dunder) name of the API module.
column_names = {
    attr: get_column_names(getattr(api, attr))
    for attr in dir(api)
    if not attr.startswith("__")
}
# Docstring sections ("function" / "input" / "output") for every public name whose
# docstring spans more than one line.
function_infos = {
    attr: get_func_info(getattr(api, attr))
    for attr in dir(api)
    if not attr.startswith("__") and "\n" in (getattr(api, attr).__doc__ or "")
}
# One row per documented name; the name itself lands in the "index" column (searched by
# UDataProxy.search_api).
df_func = pd.DataFrame(function_infos).T.reset_index()

# A separate limiter per API name (max_calls=120 per period=60, presumably seconds);
# names missing from the table share default_rate_limit.
api_rate_limit_dict = {attr: RateLimiter(max_calls=120, period=60) for attr in dir(api)}
default_rate_limit = RateLimiter(max_calls=120, period=60)


def to_cn(df: pd.DataFrame, name: str) -> pd.DataFrame:
    """
    Rename the columns of an hs_udata result from English field codes to the Chinese
    labels parsed from the docstring of API function ``name``.

    Columns without a documented label, or results of a function with no parsed mapping,
    keep their original English name instead of raising KeyError.

    :param df: DataFrame returned by an hs_udata call; its columns are replaced in place
    :param name: name of the hs_udata API function that produced ``df``
    :return: the same ``df`` object with renamed columns
    """
    mapping = column_names.get(name, {})
    df.columns = [mapping.get(code, code) for code in df.columns]
    return df


def get_rate_limit(function_name: str) -> RateLimiter:
    """
    Return the rate limiter registered for an API function.

    :param function_name: name of the hs_udata API function
    :return: its entry in ``api_rate_limit_dict``, or the shared ``default_rate_limit``
        when the name is not registered
    """
    return api_rate_limit_dict.get(function_name, default_rate_limit)


def _compute_key(args, kw):
    """
    Pickle the ``(args, kw)`` pair and return the SHA-1 hex digest of the bytes.

    NOTE(review): the pickle of a dict follows its insertion order, so ``kw`` dicts with
    the same items in a different order presumably yield different digests.
    """
    payload = pickle.dumps((args, kw))
    digest = hashlib.sha1(payload)
    return digest.hexdigest()


def proxy_api_en(name, *args, **kwargs):
    """
    Call hs_udata API function ``name`` while holding its rate limiter.

    :param name: attribute name of the function in the hs_udata stock API module
    :return: whatever the API function returns, unchanged (English column names)
    """
    limiter = get_rate_limit(name)
    with limiter:
        return getattr(api, name)(*args, **kwargs)


def proxy_api_cn(name, *args, **kwargs):
    """
    Call hs_udata API function ``name`` (rate limited) and rename the result's columns
    to Chinese.

    :param name: attribute name of the function in the hs_udata stock API module; also
        selects which docstring mapping ``to_cn`` applies
    :return: the API result with Chinese column names
    """
    # to_cn needs the function name to look up the column mapping; omitting it raised
    # TypeError on every call.
    return to_cn(proxy_api_en(name, *args, **kwargs), name)


class UDataProxy(object):
    """
    Attribute-style proxy over the hs_udata stock API module.

    Each name exposed by the API module becomes an attribute that calls the matching API
    function under its rate limiter. With ``language="cn"`` result columns are renamed to
    the Chinese labels parsed from the API docstrings; otherwise the English names are kept.
    """

    def __init__(self, language="cn"):
        """
        Initialise the proxy.

        :param language: "cn" for Chinese column names; any other value keeps English names
        """
        self.language = language
        handler = proxy_api_cn if language == "cn" else proxy_api_en
        for name in dir(api):
            # Module dunders (__doc__, __name__, __file__, ...) are not API functions and
            # would shadow this instance's own attributes. trading_calendar has a cached
            # implementation on this class. "language" must not overwrite the setting above.
            if name.startswith("__") or name in ("trading_calendar", "language"):
                continue
            self.__dict__[name] = partial(handler, name)

    def __getattr__(self, name):
        """
        Lazily create and memoise a proxy for an API name not registered in ``__init__``.

        Python only calls this when normal attribute lookup fails.

        :raises AttributeError: for dunder names, for ``language`` (e.g. on an instance
            built without ``__init__``), and for names the API module does not provide
        """
        # Refusing dunders keeps protocol probes (copy, pickle, IPython) from getting a bogus
        # callable. Refusing "language" stops the infinite recursion that ``self.language``
        # below would otherwise cause when __init__ has not run.
        if name.startswith("__") or name == "language" or not hasattr(api, name):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        handler = proxy_api_cn if self.language == "cn" else proxy_api_en
        self.__dict__[name] = partial(handler, name)
        return self.__dict__[name]

    def trading_calendar(self,
                         start_date: str = None,
                         end_date: str = None,
                         ) -> pd.DataFrame:
        """
        Return trading-calendar rows whose ``trading_date`` lies in [start_date, end_date].

        The full calendar is downloaded year by year and stored in the module disk cache
        under "trading_calendar". It is re-downloaded when it is missing, when its latest
        date is before today, or when ``start_date`` precedes its earliest date.
        NOTE(review): ``trading_date`` is assumed to hold "YYYY-MM-DD" strings, since it is
        compared against such strings below.

        :param start_date: first date, any format ``pd.to_datetime`` accepts; defaults to
            365 days before today
        :param end_date: last date; defaults to today
        :return: matching rows, with Chinese column names when ``language == "cn"``
        """
        now = dt.datetime.now()
        with cache:
            cached = cache.get("trading_calendar")
        needs_refresh = (
            cached is None
            or cached['trading_date'].max() < now.strftime("%Y-%m-%d")
            or (start_date is not None
                and convert_date(start_date) < cached['trading_date'].min().replace("-", ""))
        )
        if needs_refresh:
            # One request per year-end between ~10000 days ago and ~750 days ahead, each
            # covering the 365 days up to that year-end. pd.offsets.YearEnd() is the anchor
            # the old 'Y' alias meant, and it works on every pandas version. ``closed=`` was
            # removed in pandas 2.0; dropping it can at most add one leading year-end, which
            # only widens the download.
            # NOTE(review): the window is fixed, so a start_date earlier than the data it
            # yields triggers a full re-download on every call.
            year_ends = pd.date_range(start=now - dt.timedelta(days=10000),
                                      end=now + dt.timedelta(days=750),
                                      freq=pd.offsets.YearEnd())
            calendar_array = []
            for year_end in tqdm(year_ends):
                calendar_array.append(
                    api.trading_calendar(start_date=(year_end - dt.timedelta(days=365)).strftime("%Y%m%d"),
                                         end_date=year_end.strftime("%Y%m%d")))
            cached = pd.concat(calendar_array, ignore_index=True).drop_duplicates(subset=['trading_date'])
            with cache:
                cache.set("trading_calendar", cached, expire=86400 * 365)
        if start_date is None:
            start_date = (now - dt.timedelta(days=365)).strftime("%Y-%m-%d")
        else:
            start_date = pd.to_datetime(start_date).strftime("%Y-%m-%d")
        if end_date is None:
            end_date = now.strftime("%Y-%m-%d")
        else:
            end_date = pd.to_datetime(end_date).strftime("%Y-%m-%d")
        df_ret = cached[(cached['trading_date'] >= start_date) & (cached['trading_date'] <= end_date)]
        if self.language == 'cn':
            return to_cn(df_ret, "trading_calendar")
        return df_ret

    @staticmethod
    def search_api(keyword: str = "") -> pd.DataFrame:
        """
        Find API functions whose description or name contains ``keyword``.

        Matching uses ``Series.str.contains`` with its default settings, so ``keyword`` is
        treated as a regular expression. Each match is printed with its name, description
        and input parameters.

        :param keyword: text (regex) to search for; "" matches everything
        :return: the matching rows of ``df_func``
        """
        result = df_func[(df_func['function'].str.contains(keyword)) | (df_func['index'].str.contains(keyword))]
        for row in result.itertuples():
            print(f"""--------------------------------------------------------------------------------
    函数名:{row.index}
    功  能:{row.function.strip()}
    参  数:\n    {row.input.strip()}""")
        return result


# Ready-made proxies: cn_proxy renames result columns to Chinese, en_proxy keeps English.
cn_proxy, en_proxy = UDataProxy(language="cn"), UDataProxy(language="en")
# Process-wide pandas display setting applied on import.
pd.set_option("display.max_rows", 200)

__all__ = ["UDataProxy", "cn_proxy", "en_proxy"]
